import numpy as np
from python_ai.common.xcommon import *
import tensorflow as tf

# Demo: build a dummy loss tensor and reduce it to one total per sample.
# The sizes below are scaled down from the "original" values noted in the
# trailing comments (presumably a YOLO-style detection head — TODO confirm).
np.set_printoptions(edgeitems=200)  # print up to 200 items per edge before eliding with "..."

n_samples = 8
m, n = 4, 4  # spatial grid of m x n cells; originally 13, 13
box_per_cell = 3  # boxes per grid cell; originally 5
n_cls = 5  # number of classes; originally 20

# Last axis: 5 per-box values followed by one value per class.
# NOTE(review): the 5 is assumed to be x, y, w, h, confidence — confirm.
loss = tf.ones([n_samples, m, n, box_per_cell, 5 + n_cls])
print(loss)
print(loss.shape)

# Reduce every axis except the batch axis (0), leaving shape (n_samples,).
# Named per_sample_loss instead of `sum` so the builtin sum() is not shadowed.
per_sample_loss = tf.reduce_sum(loss, axis=[1, 2, 3, 4])
print(per_sample_loss)
print(per_sample_loss.shape)
